"""
记忆遗忘任务
每5分钟进行一次遗忘检查，根据不同的遗忘阶段删除记忆
"""

import time
import random
from typing import List

from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.manager.async_task_manager import AsyncTask

logger = get_logger("memory_forget_task")


class MemoryForgetTask(AsyncTask):
    """记忆遗忘任务，每5分钟执行一次"""

    def __init__(self):
        # 每5分钟执行一次（300秒）
        super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)

    async def run(self):
        """执行遗忘检查"""
        try:
            current_time = time.time()
            logger.info("[记忆遗忘] 开始遗忘检查...")

            # 执行4个阶段的遗忘检查
            await self._forget_stage_1(current_time)
            await self._forget_stage_2(current_time)
            await self._forget_stage_3(current_time)
            await self._forget_stage_4(current_time)

            logger.info("[记忆遗忘] 遗忘检查完成")
        except Exception as e:
            logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)

    async def _forget_stage_1(self, current_time: float):
        """
        第一次遗忘检查：
        搜集所有：记忆还未被遗忘检查过（forget_times=0），且已经是30分钟之外的记忆
        取count最高25%和最低25%，删除，然后标记被遗忘检查次数为1
        """
        try:
            # 30分钟 = 1800秒
            time_threshold = current_time - 1800

            # 查询符合条件的记忆：forget_times=0 且 end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 0) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")

            # 按count排序
            candidates.sort(key=lambda x: x.count, reverse=True)

            # 计算要删除的数量（最高25%和最低25%）
            total_count = len(candidates)
            delete_count = int(total_count * 0.25)  # 25%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段1] 删除数量为0，跳过")
                return

            # 选择要删除的记录（处理count相同的情况：随机选择）
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # 去重（避免重复删除），使用id去重
            seen_ids = set()
            unique_to_delete = []
            for record in to_delete:
                if record.id not in seen_ids:
                    seen_ids.add(record.id)
                    unique_to_delete.append(record)
            to_delete = unique_to_delete

            # 删除记录并更新forget_times
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")

            # 更新剩余记录的forget_times为1
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                # 批量更新
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=1).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段1] 完成：删除了 {deleted_count} 条记忆，更新了 {len(remaining)} 条记忆的forget_times为1"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)

    async def _forget_stage_2(self, current_time: float):
        """
        第二次遗忘检查：
        搜集所有：记忆遗忘检查为1，且已经是8小时之外的记忆
        取count最高7%和最低7%，删除，然后标记被遗忘检查次数为2
        """
        try:
            # 8小时 = 28800秒
            time_threshold = current_time - 28800

            # 查询符合条件的记忆：forget_times=1 且 end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 1) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")

            # 按count排序
            candidates.sort(key=lambda x: x.count, reverse=True)

            # 计算要删除的数量（最高7%和最低7%）
            total_count = len(candidates)
            delete_count = int(total_count * 0.07)  # 7%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段2] 删除数量为0，跳过")
                return

            # 选择要删除的记录
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # 去重
            to_delete = list(set(to_delete))

            # 删除记录
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")

            # 更新剩余记录的forget_times为2
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=2).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段2] 完成：删除了 {deleted_count} 条记忆，更新了 {len(remaining)} 条记忆的forget_times为2"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)

    async def _forget_stage_3(self, current_time: float):
        """
        第三次遗忘检查：
        搜集所有：记忆遗忘检查为2，且已经是48小时之外的记忆
        取count最高5%和最低5%，删除，然后标记被遗忘检查次数为3
        """
        try:
            # 48小时 = 172800秒
            time_threshold = current_time - 172800

            # 查询符合条件的记忆：forget_times=2 且 end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 2) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")

            # 按count排序
            candidates.sort(key=lambda x: x.count, reverse=True)

            # 计算要删除的数量（最高5%和最低5%）
            total_count = len(candidates)
            delete_count = int(total_count * 0.05)  # 5%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段3] 删除数量为0，跳过")
                return

            # 选择要删除的记录
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # 去重
            to_delete = list(set(to_delete))

            # 删除记录
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")

            # 更新剩余记录的forget_times为3
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=3).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段3] 完成：删除了 {deleted_count} 条记忆，更新了 {len(remaining)} 条记忆的forget_times为3"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)

    async def _forget_stage_4(self, current_time: float):
        """
        第四次遗忘检查：
        搜集所有：记忆遗忘检查为3，且已经是7天之外的记忆
        取count最高2%和最低2%，删除，然后标记被遗忘检查次数为4
        """
        try:
            # 7天 = 604800秒
            time_threshold = current_time - 604800

            # 查询符合条件的记忆：forget_times=3 且 end_time < time_threshold
            candidates = list(
                ChatHistory.select().where((ChatHistory.forget_times == 3) & (ChatHistory.end_time < time_threshold))
            )

            if not candidates:
                logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
                return

            logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")

            # 按count排序
            candidates.sort(key=lambda x: x.count, reverse=True)

            # 计算要删除的数量（最高2%和最低2%）
            total_count = len(candidates)
            delete_count = int(total_count * 0.02)  # 2%

            if delete_count == 0:
                logger.debug("[记忆遗忘-阶段4] 删除数量为0，跳过")
                return

            # 选择要删除的记录
            to_delete = []
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
            to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))

            # 去重
            to_delete = list(set(to_delete))

            # 删除记录
            deleted_count = 0
            for record in to_delete:
                try:
                    record.delete_instance()
                    deleted_count += 1
                except Exception as e:
                    logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")

            # 更新剩余记录的forget_times为4
            to_delete_ids = {r.id for r in to_delete}
            remaining = [r for r in candidates if r.id not in to_delete_ids]
            if remaining:
                ids_to_update = [r.id for r in remaining]
                ChatHistory.update(forget_times=4).where(ChatHistory.id.in_(ids_to_update)).execute()

            logger.info(
                f"[记忆遗忘-阶段4] 完成：删除了 {deleted_count} 条记忆，更新了 {len(remaining)} 条记忆的forget_times为4"
            )

        except Exception as e:
            logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)

    def _handle_same_count_random(
        self, candidates: List[ChatHistory], delete_count: int, mode: str
    ) -> List[ChatHistory]:
        """
        处理count相同的情况，随机选择要删除的记录

        Args:
            candidates: 候选记录列表（已按count排序）
            delete_count: 要删除的数量
            mode: "high" 表示选择最高count的记录，"low" 表示选择最低count的记录

        Returns:
            要删除的记录列表
        """
        if not candidates or delete_count == 0:
            return []

        to_delete = []

        if mode == "high":
            # 从最高count开始选择
            start_idx = 0
            while start_idx < len(candidates) and len(to_delete) < delete_count:
                # 找到所有count相同的记录
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx < len(candidates) and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx += 1

                # 如果相同count的记录数量 <= 还需要删除的数量，全部选择
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    to_delete.extend(same_count_records)
                else:
                    # 随机选择需要的数量
                    to_delete.extend(random.sample(same_count_records, needed))

                start_idx = idx

        else:  # mode == "low"
            # 从最低count开始选择
            start_idx = len(candidates) - 1
            while start_idx >= 0 and len(to_delete) < delete_count:
                # 找到所有count相同的记录
                current_count = candidates[start_idx].count
                same_count_records = []
                idx = start_idx
                while idx >= 0 and candidates[idx].count == current_count:
                    same_count_records.append(candidates[idx])
                    idx -= 1

                # 如果相同count的记录数量 <= 还需要删除的数量，全部选择
                needed = delete_count - len(to_delete)
                if len(same_count_records) <= needed:
                    to_delete.extend(same_count_records)
                else:
                    # 随机选择需要的数量
                    to_delete.extend(random.sample(same_count_records, needed))

                start_idx = idx

        return to_delete
